"""
Unit tests for rerank document chunking functionality.

Tests the chunk_documents_for_rerank and aggregate_chunk_scores functions
in lightrag/rerank.py to ensure proper document splitting and score aggregation.
"""

import pytest
from unittest.mock import Mock, patch, AsyncMock
from lightrag.rerank import (
    chunk_documents_for_rerank,
    aggregate_chunk_scores,
    cohere_rerank,
)


class TestChunkDocumentsForRerank:
    """Test suite for chunk_documents_for_rerank function.

    Every test unpacks the function's two return values: the list of chunk
    strings and a parallel list giving, for each chunk, the index of the
    original document it came from.
    """

    def test_no_chunking_needed_for_short_docs(self):
        """Documents shorter than max_tokens should not be chunked"""
        documents = [
            "Short doc 1",
            "Short doc 2",
            "Short doc 3",
        ]

        chunked_docs, doc_indices = chunk_documents_for_rerank(
            documents, max_tokens=100, overlap_tokens=10
        )

        # No chunking should occur: one chunk per document, in input order
        assert len(chunked_docs) == 3
        assert chunked_docs == documents
        assert doc_indices == [0, 1, 2]

    def test_chunking_with_character_fallback(self):
        """Test chunking falls back to character-based when tokenizer unavailable"""
        # Create a very long document that exceeds character limit
        long_doc = "a" * 2000  # 2000 characters
        documents = [long_doc, "short doc"]

        # Make constructing the tokenizer raise ImportError so the tokenizer
        # path cannot be used for this call.
        # NOTE(review): the "~4 chars per token" figures below are the fallback
        # ratio assumed by this test — confirm against lightrag/rerank.py.
        with patch("lightrag.utils.TiktokenTokenizer", side_effect=ImportError):
            chunked_docs, doc_indices = chunk_documents_for_rerank(
                documents,
                max_tokens=100,  # 100 tokens = ~400 chars
                overlap_tokens=10,  # 10 tokens = ~40 chars
            )

        # First doc should be split into chunks, second doc stays whole
        assert len(chunked_docs) > 2  # At least one chunk from first doc + second doc
        assert chunked_docs[-1] == "short doc"  # Last chunk is the short doc
        # Verify doc_indices maps chunks to correct original document
        assert doc_indices[-1] == 1  # Last chunk maps to document 1

    def test_chunking_with_tiktoken_tokenizer(self):
        """Test chunking with actual tokenizer"""
        # Create document with known token count
        # Approximate: "word " = ~1 token, so 200 words ~ 200 tokens
        long_doc = " ".join([f"word{i}" for i in range(200)])
        documents = [long_doc, "short"]

        chunked_docs, doc_indices = chunk_documents_for_rerank(
            documents, max_tokens=50, overlap_tokens=10
        )

        # Long doc should be split, short doc should remain
        assert len(chunked_docs) > 2
        assert doc_indices[-1] == 1  # Last chunk is from second document

        # Walk adjacent (index, chunk) pairs; only pairs where both chunks come
        # from the first document are expected to share overlapping content.
        indexed_chunks = list(zip(doc_indices, chunked_docs))
        for (left_idx, left_chunk), (right_idx, right_chunk) in zip(
            indexed_chunks, indexed_chunks[1:]
        ):
            if left_idx == right_idx == 0:
                tail_words = left_chunk.split()[-5:]
                right_words = right_chunk.split()
                # At least one word should be common due to overlap
                assert any(
                    word in right_words for word in tail_words
                ), "consecutive chunks of the same document share no words"

    def test_empty_documents(self):
        """Test handling of empty document list"""
        documents = []
        chunked_docs, doc_indices = chunk_documents_for_rerank(documents)

        assert chunked_docs == []
        assert doc_indices == []

    def test_single_document_chunking(self):
        """Test chunking of a single long document"""
        # Create document with ~100 tokens
        long_doc = " ".join([f"token{i}" for i in range(100)])
        documents = [long_doc]

        chunked_docs, doc_indices = chunk_documents_for_rerank(
            documents, max_tokens=30, overlap_tokens=5
        )

        # Should create multiple chunks
        assert len(chunked_docs) > 1
        # All chunks should map to document 0
        assert all(idx == 0 for idx in doc_indices)


class TestAggregateChunkScores:
    """Checks how aggregate_chunk_scores folds per-chunk scores back onto documents.

    Inputs are a list of ``{"index", "relevance_score"}`` chunk results, a
    chunk-to-document index list, the number of original documents and the
    name of the aggregation strategy.
    """

    def test_no_chunking_simple_aggregation(self):
        """With a 1:1 chunk/document mapping the scores pass straight through"""
        scores = [0.9, 0.7, 0.5]
        chunk_results = [
            {"index": pos, "relevance_score": score} for pos, score in enumerate(scores)
        ]

        aggregated = aggregate_chunk_scores(
            chunk_results, [0, 1, 2], 3, aggregation="max"
        )

        # Expect one entry per document, highest score first
        assert len(aggregated) == 3
        ranked = [(entry["index"], entry["relevance_score"]) for entry in aggregated]
        assert ranked == [(0, 0.9), (1, 0.7), (2, 0.5)]

    def test_max_aggregation_with_chunks(self):
        """The "max" strategy keeps the best chunk score of each document"""
        # Chunks 0-2 belong to document 0, chunks 3-4 to document 1
        per_chunk_scores = [0.5, 0.8, 0.6, 0.7, 0.4]
        chunk_results = [
            {"index": pos, "relevance_score": score}
            for pos, score in enumerate(per_chunk_scores)
        ]
        chunk_owner = [0, 0, 0, 1, 1]

        aggregated = aggregate_chunk_scores(
            chunk_results, chunk_owner, 2, aggregation="max"
        )

        assert len(aggregated) == 2
        ranked = [(entry["index"], entry["relevance_score"]) for entry in aggregated]
        # doc 0 -> max(0.5, 0.8, 0.6); doc 1 -> max(0.7, 0.4)
        assert ranked == [(0, 0.8), (1, 0.7)]

    def test_mean_aggregation_with_chunks(self):
        """The "mean" strategy averages the chunk scores of each document"""
        chunk_results = [
            {"index": pos, "relevance_score": score}
            for pos, score in enumerate([0.6, 0.8, 0.4])
        ]
        chunk_owner = [0, 0, 1]  # two chunks for doc 0, one for doc 1

        aggregated = aggregate_chunk_scores(
            chunk_results, chunk_owner, 2, aggregation="mean"
        )

        assert len(aggregated) == 2
        best, second = aggregated
        assert best["index"] == 0
        # (0.6 + 0.8) / 2, compared approximately because it is a computed float
        assert best["relevance_score"] == pytest.approx(0.7)
        assert second["index"] == 1
        assert second["relevance_score"] == 0.4

    def test_first_aggregation_with_chunks(self):
        """The "first" strategy keeps the first score encountered per document"""
        chunk_results = [
            {"index": pos, "relevance_score": score}
            for pos, score in enumerate([0.6, 0.8, 0.4])
        ]
        chunk_owner = [0, 0, 1]

        aggregated = aggregate_chunk_scores(
            chunk_results, chunk_owner, 2, aggregation="first"
        )

        assert len(aggregated) == 2
        ranked = [(entry["index"], entry["relevance_score"]) for entry in aggregated]
        # doc 0 keeps 0.6 (its first chunk), doc 1 keeps its only score
        assert ranked == [(0, 0.6), (1, 0.4)]

    def test_empty_chunk_results(self):
        """No chunk results in, no aggregated results out"""
        assert aggregate_chunk_scores([], [], 3, aggregation="max") == []

    def test_documents_with_no_scores(self):
        """Documents that received no chunk score are left out of the output"""
        chunk_results = [
            {"index": pos, "relevance_score": score}
            for pos, score in enumerate([0.9, 0.7])
        ]
        chunk_owner = [0, 0]  # every scored chunk belongs to document 0
        total_docs = 3  # ...although three documents exist overall

        aggregated = aggregate_chunk_scores(
            chunk_results, chunk_owner, total_docs, aggregation="max"
        )

        # Only document 0 shows up
        assert len(aggregated) == 1
        assert aggregated[0]["index"] == 0

    def test_unknown_aggregation_strategy(self):
        """An unrecognised strategy name behaves like "max\""""
        chunk_results = [
            {"index": pos, "relevance_score": score}
            for pos, score in enumerate([0.6, 0.8])
        ]

        # "invalid" is not a known strategy name
        aggregated = aggregate_chunk_scores(
            chunk_results, [0, 0], 1, aggregation="invalid"
        )

        # The higher of the two scores wins, i.e. max semantics
        assert aggregated[0]["relevance_score"] == 0.8


@pytest.mark.offline
class TestTopNWithChunking:
    """Tests for top_n behavior when chunking is enabled (Bug fix verification)"""

    @staticmethod
    def _make_mock_session(results, captured_payload=None):
        """Build a Mock that stands in for ``aiohttp.ClientSession``.

        Args:
            results: List stored under the ``"results"`` key of the JSON body
                returned by the fake HTTP response.
            captured_payload: Optional dict. When given, it is updated with the
                ``json=`` keyword argument of every ``post`` call so the test
                can inspect what would have been sent.

        Returns:
            A Mock usable as ``async with session`` whose ``post(...)`` returns
            a response Mock that is itself usable as an async context manager.
        """
        mock_response = Mock()
        mock_response.status = 200
        mock_response.json = AsyncMock(return_value={"results": results})
        mock_response.request_info = None
        mock_response.history = None
        mock_response.headers = {}
        # `async with session.post(...) as response` needs __aenter__/__aexit__
        mock_response.__aenter__ = AsyncMock(return_value=mock_response)
        mock_response.__aexit__ = AsyncMock(return_value=None)

        def fake_post(*args, **kwargs):
            if captured_payload is not None:
                captured_payload.update(kwargs.get("json", {}))
            return mock_response

        mock_session = Mock()
        mock_session.post = Mock(side_effect=fake_post)
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)
        return mock_session

    @pytest.mark.asyncio
    async def test_top_n_limits_documents_not_chunks(self):
        """
        Test that top_n correctly limits documents (not chunks) when chunking is enabled.

        Bug scenario: 10 docs expand to 50 chunks. With old behavior, top_n=5 would
        return scores for only 5 chunks (possibly all from 1-2 docs). After aggregation,
        fewer than 5 documents would be returned.

        Fixed behavior: top_n=5 should return exactly 5 documents after aggregation.

        This test uses a smaller instance of the same scenario: 5 documents and top_n=3.
        """
        # Setup: 5 documents, each producing multiple chunks when chunked
        # Using small max_tokens to force chunking
        long_docs = [" ".join([f"doc{i}_word{j}" for j in range(50)]) for i in range(5)]
        query = "test query"

        # First, determine how many chunks will be created by actual chunking.
        # NOTE(review): overlap_tokens=10 is passed here, while the cohere_rerank
        # call below sets no overlap; presumably both produce the same chunk
        # layout — confirm against lightrag/rerank.py.
        _, doc_indices = chunk_documents_for_rerank(
            long_docs, max_tokens=50, overlap_tokens=10
        )

        # Mock API returns scores for ALL chunks (simulating disabled API-level top_n).
        # Every chunk gets a score derived only from its original document index
        # (lower doc index = higher score), so doc 0 ranks first, doc 1 second, etc.
        mock_chunk_scores = [
            {"index": chunk_idx, "relevance_score": 0.9 - (original_doc * 0.1)}
            for chunk_idx, original_doc in enumerate(doc_indices)
        ]

        mock_session = self._make_mock_session(mock_chunk_scores)

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
            result = await cohere_rerank(
                query=query,
                documents=long_docs,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=True,
                max_tokens_per_doc=50,  # Match chunking above
                top_n=3,  # Request top 3 documents
            )

            # Verify: should get exactly 3 documents (not unlimited chunks)
            assert len(result) == 3
            # All results should have valid document indices (0-4)
            assert all(0 <= r["index"] < 5 for r in result)
            # Results should be sorted by score (descending)
            assert all(
                result[i]["relevance_score"] >= result[i + 1]["relevance_score"]
                for i in range(len(result) - 1)
            )
            # The top 3 docs should be 0, 1, 2 (highest scores)
            result_indices = [r["index"] for r in result]
            assert set(result_indices) == {0, 1, 2}

    @pytest.mark.asyncio
    async def test_api_receives_no_top_n_when_chunking_enabled(self):
        """
        Test that the API request does NOT include top_n when chunking is enabled.

        This ensures all chunk scores are retrieved for proper aggregation.
        """
        documents = [" ".join([f"word{i}" for i in range(100)]), "short doc"]
        query = "test query"

        captured_payload = {}
        mock_session = self._make_mock_session(
            [
                {"index": 0, "relevance_score": 0.9},
                {"index": 1, "relevance_score": 0.8},
                {"index": 2, "relevance_score": 0.7},
            ],
            captured_payload,
        )

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
            await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=True,
                max_tokens_per_doc=30,
                top_n=1,  # User wants top 1 document
            )

            # Guard against a vacuous pass: an empty dict would also satisfy the
            # "top_n not in payload" check if no request payload was captured.
            assert captured_payload, "no JSON payload was captured from session.post"
            # Verify: API payload should NOT have top_n (disabled for chunking)
            assert "top_n" not in captured_payload

    @pytest.mark.asyncio
    async def test_top_n_not_modified_when_chunking_disabled(self):
        """
        Test that top_n is passed through to API when chunking is disabled.
        """
        documents = ["doc1", "doc2"]
        query = "test query"

        captured_payload = {}
        mock_session = self._make_mock_session(
            [{"index": 0, "relevance_score": 0.9}],
            captured_payload,
        )

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=mock_session):
            await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=False,  # Chunking disabled
                top_n=1,
            )

            # Verify: API payload should have top_n when chunking is disabled
            assert captured_payload.get("top_n") == 1


@pytest.mark.offline
class TestCohereRerankChunking:
    """Integration tests for cohere_rerank with chunking enabled.

    Each test replaces ``lightrag.rerank.generic_rerank_api`` with an
    ``AsyncMock`` and inspects the keyword arguments cohere_rerank passes to it.
    """

    @pytest.mark.asyncio
    async def test_cohere_rerank_with_chunking_disabled(self):
        """Test that chunking can be disabled"""
        documents = ["doc1", "doc2"]
        query = "test query"

        # Mock the generic_rerank_api
        with patch(
            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
        ) as mock_api:
            mock_api.return_value = [
                {"index": 0, "relevance_score": 0.9},
                {"index": 1, "relevance_score": 0.7},
            ]

            result = await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                enable_chunking=False,
                max_tokens_per_doc=100,
            )

            # Verify generic_rerank_api was called with correct parameters
            mock_api.assert_called_once()
            call_kwargs = mock_api.call_args.kwargs
            assert call_kwargs["enable_chunking"] is False
            assert call_kwargs["max_tokens_per_doc"] == 100
            # Result should mirror mocked scores
            assert len(result) == 2
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9
            assert result[1]["index"] == 1
            assert result[1]["relevance_score"] == 0.7

    @pytest.mark.asyncio
    async def test_cohere_rerank_with_chunking_enabled(self):
        """Test that chunking parameters are passed through"""
        documents = ["doc1", "doc2"]
        query = "test query"

        with patch(
            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
        ) as mock_api:
            mock_api.return_value = [
                {"index": 0, "relevance_score": 0.9},
                {"index": 1, "relevance_score": 0.7},
            ]

            result = await cohere_rerank(
                query=query,
                documents=documents,
                api_key="test-key",
                enable_chunking=True,
                max_tokens_per_doc=480,
            )

            # Fail with a clear message if the API was never reached, instead of
            # a TypeError from subscripting a call_args that is None.
            mock_api.assert_called_once()
            # Verify parameters were passed
            call_kwargs = mock_api.call_args.kwargs
            assert call_kwargs["enable_chunking"] is True
            assert call_kwargs["max_tokens_per_doc"] == 480
            # Result should mirror mocked scores
            assert len(result) == 2
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9
            assert result[1]["index"] == 1
            assert result[1]["relevance_score"] == 0.7

    @pytest.mark.asyncio
    async def test_cohere_rerank_default_parameters(self):
        """Test default parameter values for cohere_rerank"""
        documents = ["doc1"]
        query = "test"

        with patch(
            "lightrag.rerank.generic_rerank_api", new_callable=AsyncMock
        ) as mock_api:
            mock_api.return_value = [{"index": 0, "relevance_score": 0.9}]

            result = await cohere_rerank(
                query=query, documents=documents, api_key="test-key"
            )

            # Fail with a clear message if the API was never reached, instead of
            # a TypeError from subscripting a call_args that is None.
            mock_api.assert_called_once()
            # Verify default values forwarded when the caller sets none of them
            call_kwargs = mock_api.call_args.kwargs
            assert call_kwargs["enable_chunking"] is False
            assert call_kwargs["max_tokens_per_doc"] == 4096
            assert call_kwargs["model"] == "rerank-v3.5"
            # Result should mirror mocked scores
            assert len(result) == 1
            assert result[0]["index"] == 0
            assert result[0]["relevance_score"] == 0.9


@pytest.mark.offline
class TestEndToEndChunking:
    """Drives cohere_rerank with chunking enabled against a mocked HTTP session"""

    @pytest.mark.asyncio
    async def test_end_to_end_chunking_workflow(self):
        """Run documents through cohere_rerank and check the aggregated output"""
        # One document long enough to need chunking, one that is not
        oversized = " ".join([f"word{n}" for n in range(100)])
        documents = [oversized, "short doc"]

        # Canned API answer with four scored entries. They are intended as three
        # chunks of the long document followed by the short document; the real
        # mapping depends on how the chunker splits the input.
        canned_scores = [0.5, 0.8, 0.6, 0.7]
        api_body = {
            "results": [
                {"index": pos, "relevance_score": score}
                for pos, score in enumerate(canned_scores)
            ]
        }

        # Fake HTTP response, usable in `async with session.post(...) as response`
        response = Mock()
        response.status = 200
        response.json = AsyncMock(return_value=api_body)
        response.request_info = None
        response.history = None
        response.headers = {}
        response.__aenter__ = AsyncMock(return_value=response)
        response.__aexit__ = AsyncMock(return_value=None)

        # Fake client session whose post() hands back the response above
        session = Mock()
        session.post = Mock(return_value=response)
        session.__aenter__ = AsyncMock(return_value=session)
        session.__aexit__ = AsyncMock(return_value=None)

        with patch("lightrag.rerank.aiohttp.ClientSession", return_value=session):
            result = await cohere_rerank(
                query="test query",
                documents=documents,
                api_key="test-key",
                base_url="http://test.com/rerank",
                enable_chunking=True,
                max_tokens_per_doc=30,  # small limit so the long doc is chunked
            )

        # Chunks are aggregated per document, so there can never be more
        # results than original documents.
        assert len(result) <= len(documents)
        # Scores must be in non-increasing order
        ranked_scores = [entry["relevance_score"] for entry in result]
        assert all(
            higher >= lower for higher, lower in zip(ranked_scores, ranked_scores[1:])
        )
